/**
 * 
 */
package edu.umd.clip.lm.tools;

import java.io.*;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.*;

import edu.berkeley.nlp.util.*;
import edu.umd.clip.jobs.JobManager;
import edu.umd.clip.lm.util.IO;
import edu.umd.clip.lm.util.algo.FastMDI;
import edu.umd.clip.lm.util.tree.AnyaryTree;
import edu.umd.clip.lm.util.tree.AnyaryTreeFinder;

/**
 * Command-line tool that clusters the words of a training text into a binary
 * class hierarchy using {@link FastMDI}.
 * <p>
 * The tool makes two passes over the input text. The first pass collects the
 * vocabulary; the second pass records, for every occurrence of a clusterable
 * word, its immediate left and right neighbour (sentence boundaries are
 * represented by {@code <s>} and {@code </s>}). The resulting statistics are
 * handed to {@link FastMDI#partition}, whose split notifications are used to
 * build a binary tree. Every leaf of that tree is written out as one class,
 * one line per word: {@code <class name> <P(word | class)> <word>}.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public class ClusterWordsFastMDI {

	/** Command-line options; the fields are filled in by the option parser. */
	public static class Options {
		@Option(name = "-input", required = true, usage = "Training text file")
		public String input;
		@Option(name = "-output", required = false, usage = "Class output file (Default: stdout)")
		public String output;
		@Option(name = "-config", usage = "XML config file")
		public String config;
		@Option(name = "-jobs", usage = "number of concurrent jobs (default: 1)")
		public int jobs = 1;
		@Option(name = "-no-cluster", required = false, usage = "List of words not to cluster")
		public String noCluster;
		@Option(name = "-levels", required = true, usage = "Number of clusters is 2^levels")
		public int levels;
	}

	private static final String SENT_START = "<s>";
	private static final String SENT_END = "</s>";

	private static final Pattern wordSplitRE = Pattern.compile("\\s+");

	/** Encoding of the no-cluster word list and of the class output. */
	private static final Charset UTF8 = Charset.forName("UTF-8");

	/**
	 * A context feature: a neighbouring word together with the side on which
	 * it appears ({@code left == true} means it precedes the clustered word).
	 * <p>
	 * The fields are deliberately mutable so that a single scratch instance
	 * can be reused for map lookups; instances stored as map keys must never
	 * be modified.
	 */
	private static class Param {
		String word;
		boolean left;

		public Param(String word, boolean left) {
			this.word = word;
			this.left = left;
		}

		@Override
		public int hashCode() {
			final int prime = 31;
			int result = 1;
			result = prime * result + (left ? 1231 : 1237);
			result = prime * result + ((word == null) ? 0 : word.hashCode());
			return result;
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (!(obj instanceof Param)) {
				return false;
			}
			Param other = (Param) obj;
			if (left != other.left) {
				return false;
			}
			return word == null ? other.word == null : word.equals(other.word);
		}
	}

	/**
	 * Entry point: parses the options, starts the job manager thread and runs
	 * the clustering, writing the classes to {@code -output} (or stdout).
	 *
	 * @param args command-line arguments, see {@link Options}
	 * @throws IOException if any input cannot be read or the output cannot be written
	 */
	public static void main(String[] args) throws IOException {
		OptionParser optParser = new OptionParser(Options.class);
		final Options opts = (Options) optParser.parse(args, true);

		JobManager.initialize(opts.jobs);
		Thread thread = new Thread(JobManager.getInstance(), "Job Manager");
		thread.setDaemon(true);
		thread.start();

		// Open the output first so that an unwritable destination is reported
		// before the (potentially long) clustering starts.
		BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
				opts.output == null ? System.out : new FileOutputStream(opts.output), UTF8));
		try {
			cluster(opts, writer);
		} finally {
			// Closed on every path, including the early exit and failures.
			writer.close();
		}
	}

	/**
	 * Runs the whole pipeline: vocabulary collection, context counting,
	 * partitioning and output of the resulting classes.
	 */
	private static void cluster(final Options opts, Writer writer) throws IOException {
		Map<String,String> noClusterDict = new HashMap<String,String>();
		ConcurrentHashMap<String,String> wordVocabHash = new ConcurrentHashMap<String,String>(1000000);

		wordVocabHash.put(SENT_START, SENT_START);
		wordVocabHash.put(SENT_END, SENT_END);

		if (opts.noCluster != null) {
			readNoClusterWords(opts.noCluster, noClusterDict, wordVocabHash);
		}
		// Sentence boundary markers are contexts only, never clustered.
		noClusterDict.put(SENT_START, SENT_START);
		noClusterDict.put(SENT_END, SENT_END);

		collectVocabulary(opts.input, wordVocabHash);

		if (noClusterDict.size() >= wordVocabHash.size()) {
			// Every known word is excluded: there is nothing to cluster.
			System.err.println("noClusterDict.size() >= dictionary.size()");
			return;
		}

		// noClusterDict is a subset of wordVocabHash, so the difference is the
		// exact number of clusterable words.
		int vocabSize = wordVocabHash.size();
		int clusterableSize = vocabSize - noClusterDict.size();

		final ArrayList<String> wordVocab = new ArrayList<String>(clusterableSize);
		ArrayList<Param> otherVocab = new ArrayList<Param>(vocabSize * 2);

		HashMap<String,Integer> word2idx = new HashMap<String,Integer>(hashCapacity(clusterableSize));
		HashMap<Param,Integer> other2idx = new HashMap<Param,Integer>(hashCapacity(vocabSize * 2));

		// Every vocabulary word (clusterable or not) yields two context
		// features; only clusterable words get a word index.
		for (String word : wordVocabHash.keySet()) {
			Param right = new Param(word, false);
			other2idx.put(right, otherVocab.size());
			otherVocab.add(right);

			Param left = new Param(word, true);
			other2idx.put(left, otherVocab.size());
			otherVocab.add(left);

			if (!noClusterDict.containsKey(word)) {
				word2idx.put(word, wordVocab.size());
				wordVocab.add(word);
			}
		}

		FastMDI<String,Param> algo = new FastMDI<String,Param>(
				wordVocab.toArray(new String[0]), otherVocab.toArray(new Param[0]));

		long[] unigramCounts = new long[wordVocab.size()];
		accumulateCounts(opts.input, noClusterDict, word2idx, other2idx, algo, unigramCounts);

		// Keyed by array identity: FastMDI reports a split by passing back the
		// very array instance that was registered for the cluster being split.
		final Map<int[], AnyaryTree<String>> treeNodes = new IdentityHashMap<int[], AnyaryTree<String>>();

		int[] clusterVocab = new int[wordVocab.size()];
		for (int i = 0; i < clusterVocab.length; ++i) {
			clusterVocab[i] = i;
		}

		AnyaryTree<String> treeRoot = new AnyaryTree<String>(null);
		treeNodes.put(clusterVocab, treeRoot);

		algo.setNotifier(new FastMDI.Notifier() {

			/**
			 * Records a split of {@code oldCluster} into two sub-clusters as
			 * two new children of the corresponding tree node.
			 * Synchronized because {@code treeNodes} is not thread-safe
			 * (presumably splits may be reported from several job threads).
			 *
			 * @return whether the depth limit given by {@code -levels} has
			 *         not been reached yet
			 */
			public synchronized boolean notify(int[] oldCluster, int[] cluster1, int[] cluster2) {
				AnyaryTree<String> node = treeNodes.get(oldCluster);

				String label1 = null;
				if (cluster1.length == 1) {
					label1 = wordVocab.get(cluster1[0]);
				}
				AnyaryTree<String> subnode1 = new AnyaryTree<String>(label1);
				node.addChild(subnode1);

				String label2 = null;
				if (cluster2.length == 1) {
					label2 = wordVocab.get(cluster2[0]);
				}
				AnyaryTree<String> subnode2 = new AnyaryTree<String>(label2);
				node.addChild(subnode2);

				treeNodes.put(cluster1, subnode1);
				treeNodes.put(cluster2, subnode2);

				// Level = 1 + number of nodes on the path from the split node
				// up to the root (so the first split is reported as level 2).
				int level = 1;
				for (AnyaryTree<String> n = node; n != null; n = n.getParent()) {
					++level;
				}

				System.err.printf("[%d] split %d words into %d and %d\n", level, oldCluster.length, cluster1.length, cluster2.length);

				return level <= opts.levels;
			}
		});

		algo.partition(clusterVocab);

		labelNodes(treeRoot);
		writeClusters(treeNodes, wordVocab, unigramCounts, writer);
	}

	/**
	 * Returns a {@link HashMap} initial capacity large enough to hold
	 * {@code expectedEntries} entries without rehashing at the default load
	 * factor of 0.75.
	 */
	private static int hashCapacity(int expectedEntries) {
		return (int) Math.min(expectedEntries / 0.75 + 1.0, (double) (1 << 30));
	}

	/**
	 * Reads the list of words that must not be clustered (one per line, UTF-8)
	 * and adds each of them to both the exclusion set and the vocabulary.
	 * Blank lines are ignored.
	 */
	private static void readNoClusterWords(String path, Map<String,String> noClusterDict,
			Map<String,String> wordVocabHash) throws IOException {
		BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(path), UTF8));
		try {
			for (String line = reader.readLine(); line != null; line = reader.readLine()) {
				line = line.trim();
				if (line.isEmpty()) {
					// An empty "word" can never occur as a token of the text.
					continue;
				}
				noClusterDict.put(line, line);
				wordVocabHash.put(line, line);
			}
		} finally {
			reader.close();
		}
	}

	/** First pass over the text: adds every whitespace-separated token to the vocabulary. */
	private static void collectVocabulary(String input, ConcurrentHashMap<String,String> wordVocabHash)
			throws IOException {
		BufferedReader reader = IO.getReader(IO.getInputStream(input));
		try {
			for (String line = reader.readLine(); line != null; line = reader.readLine()) {
				line = line.trim();
				if (line.isEmpty()) {
					continue;
				}
				for (String word : wordSplitRE.split(line)) {
					wordVocabHash.putIfAbsent(word, word);
				}
			}
		} finally {
			reader.close();
		}
	}

	/**
	 * Second pass over the text: for every occurrence of a clusterable word,
	 * increments its unigram count and adds a mass of 1.0 to {@code algo} for
	 * the (word, left neighbour) and (word, right neighbour) pairs. The
	 * neighbours are the adjacent tokens of the line, or the sentence boundary
	 * markers at the ends of the line.
	 */
	private static void accumulateCounts(String input, Map<String,String> noClusterDict,
			Map<String,Integer> word2idx, Map<Param,Integer> other2idx,
			FastMDI<String,Param> algo, long[] unigramCounts) throws IOException {
		BufferedReader reader = IO.getReader(IO.getInputStream(input));
		try {
			// Scratch key reused for all lookups; it is never stored in a map.
			Param param = new Param(null, false);

			for (String line = reader.readLine(); line != null; line = reader.readLine()) {
				line = line.trim();
				if (line.isEmpty()) {
					continue;
				}

				String[] sentence = wordSplitRE.split(line);
				int last = sentence.length - 1;

				for (int i = 0; i <= last; ++i) {
					if (noClusterDict.containsKey(sentence[i])) {
						continue;
					}

					int wordIdx = word2idx.get(sentence[i]);
					unigramCounts[wordIdx] += 1;

					param.word = (i == 0) ? SENT_START : sentence[i - 1];
					param.left = true;
					int leftOtherIdx = other2idx.get(param);

					param.word = (i == last) ? SENT_END : sentence[i + 1];
					param.left = false;
					int rightOtherIdx = other2idx.get(param);

					algo.addProb(wordIdx, leftOtherIdx, 1.0);
					algo.addProb(wordIdx, rightOtherIdx, 1.0);
				}
			}
		} finally {
			reader.close();
		}
	}

	/**
	 * Names every node of the tree after its path from the root: the root is
	 * {@code "CLASS-"}, and each child appends {@code '0'} if it is the first
	 * child of its parent and {@code '1'} otherwise. Any payload set earlier
	 * is overwritten.
	 */
	private static void labelNodes(AnyaryTree<String> treeRoot) {
		AnyaryTreeFinder<String> finder = new AnyaryTreeFinder<String>() {

			@Override
			public boolean apply(AnyaryTree<String> node) {
				AnyaryTree<String> parent = node.getParent();
				if (parent == null) {
					node.setPayload("CLASS-");
				} else {
					boolean firstChild = parent.getChildren().iterator().next() == node;
					node.setPayload(parent.getPayload() + (firstChild ? '0' : '1'));
				}
				// Never "found": keep going so that the whole tree is visited.
				return false;
			}
		};
		treeRoot.searchPreorder(finder);
	}

	/**
	 * Writes one line per word of every leaf cluster in the form
	 * {@code <class name> <probability> <word>}, where the probability is the
	 * word's unigram count divided by the total count of its cluster.
	 */
	private static void writeClusters(Map<int[], AnyaryTree<String>> treeNodes, List<String> wordVocab,
			long[] unigramCounts, Writer writer) throws IOException {
		for (Map.Entry<int[], AnyaryTree<String>> e : treeNodes.entrySet()) {
			AnyaryTree<String> node = e.getValue();
			if (!node.isLeaf()) {
				continue;
			}

			int[] wordIdxs = e.getKey();
			String clusterName = node.getPayload();

			long totalClusterCount = 0;
			for (int wordIdx : wordIdxs) {
				totalClusterCount += unigramCounts[wordIdx];
			}
			for (int wordIdx : wordIdxs) {
				double prob = (double) unigramCounts[wordIdx] / totalClusterCount;

				// Locale.ROOT keeps the decimal separator a '.' regardless of
				// the platform's default locale, so the file stays parseable.
				writer.write(String.format(Locale.ROOT, "%s %f %s\n", clusterName, prob, wordVocab.get(wordIdx)));
			}
		}
	}

}
